# -*- coding: utf-8 -*-
"""
Created on Sun Jun  2 23:40:01 2024

@author: Yuhan Zhu
"""

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

from sklearn.utils import resample
from xml.etree import ElementTree as ET
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

import pandas as pd

from tensorflow.keras.applications import MobileNetV2


def parse_xml(xml_file):
    """Read the labelled objects from a VOC-style XML annotation.

    Parameters
    ----------
    xml_file : str or file object
        Anything accepted by ``xml.etree.ElementTree.parse``.

    Returns
    -------
    list of (str, (int, int, int, int))
        One ``(name, (xmin, ymin, xmax, ymax))`` tuple per ``<object>`` element
        that is a direct child of the root, in document order.
    """
    root = ET.parse(xml_file).getroot()
    labels = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        bbox = obj.find('bndbox')
        # Some annotation tools write coordinates as decimals ("12.0"), which
        # int() rejects; going through float() accepts both forms and truncates
        # exactly like int() does for integer strings.
        coords = tuple(
            int(float(bbox.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        labels.append((label, coords))
    return labels


def crop_image(image, bbox, size1=128, size2=128):
    """Cut a ``size2`` x ``size1`` window centred on *bbox*, kept inside *image*.

    Parameters
    ----------
    image : numpy.ndarray
        Image indexed as ``image[row, col, ...]`` (2-D or 3-D).
    bbox : sequence of 4 numbers
        ``(xmin, ymin, xmax, ymax)``; only its centre is used.
    size1, size2 : int
        Crop width and height in pixels.

    Returns
    -------
    numpy.ndarray
        A slice of *image*. The window is shifted back inside the image when it
        would cross the right or bottom edge. If the image is smaller than the
        crop along an axis, the whole extent of that axis is returned.
    """
    height, width = image.shape[:2]
    # int() lets float bbox coordinates work as slice indices.
    x_center = int((bbox[0] + bbox[2]) // 2)
    y_center = int((bbox[1] + bbox[3]) // 2)

    x_start = max(x_center - size1 // 2, 0)
    y_start = max(y_center - size2 // 2, 0)
    # Slide the window back if it runs past the right/bottom edge, but never to
    # a negative start: a negative index wraps around to the far side of the
    # array and would return a thin sliver instead of the image.
    x_start = max(min(x_start, width - size1), 0)
    y_start = max(min(y_start, height - size2), 0)

    return image[y_start:y_start + size2, x_start:x_start + size1]


def random_crop(image, size1=128, size2=128):
    """Return a ``size2`` x ``size1`` window at a uniformly random position.

    Every valid top-left corner, including the last one, can be chosen. If the
    image is smaller than the crop along an axis, that axis starts at 0 and the
    slice is truncated to the image. Works for 2-D and 3-D images.
    """
    h, w = image.shape[:2]
    # randint's upper bound is exclusive, hence "+ 1". max(..., 0) keeps the
    # bound positive for images smaller than the crop.
    x_start = np.random.randint(0, max(w - size1, 0) + 1)
    y_start = np.random.randint(0, max(h - size2, 0) + 1)
    return image[y_start:y_start + size2, x_start:x_start + size1]


def batch_crop_images(data_dir, label_dir, output_dir):
    """Crop every annotated image in *data_dir* to 128x128 and save it to *output_dir*.

    An image is processed only if ``label_dir`` holds an XML file with the same
    stem. If the annotation contains an object named ``'person'``, the crop is
    centred on the first such box; otherwise a random window is taken. The crop
    is written under the original file name. Files OpenCV cannot read are
    skipped with a message instead of aborting the whole batch.
    """
    os.makedirs(output_dir, exist_ok=True)

    for image_file in os.listdir(data_dir):
        image_path = os.path.join(data_dir, image_file)
        label_file = os.path.join(label_dir, os.path.splitext(image_file)[0] + '.xml')
        output_path = os.path.join(output_dir, image_file)

        if not os.path.exists(label_file):
            continue

        labels_info = parse_xml(label_file)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None (it does not raise) for unreadable or
            # non-image files; cropping None would crash.
            print(f"Skipping unreadable image: {image_path}")
            continue

        person_bbox = next(
            (bbox for label, bbox in labels_info if label == 'person'), None
        )
        if person_bbox is not None:
            cropped_img = crop_image(img, person_bbox, size1=128, size2=128)
        else:
            cropped_img = random_crop(img, size1=128, size2=128)

        if not cv2.imwrite(output_path, cropped_img):
            print(f"Failed to write cropped image: {output_path}")




def load_data(data_dir, label_dir):
    """Load every annotated image in *data_dir* with a one-hot person label.

    An image is loaded only if ``label_dir`` holds an XML file with the same
    stem. Pixels are scaled by 1/255 and stored as float16, in the channel
    order returned by ``cv2.imread``. The label is ``[1, 0]`` when any
    annotated object is named ``'person'`` and ``[0, 1]`` otherwise. Files
    OpenCV cannot read are skipped with a message.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Images and labels, in ``os.listdir`` order.
    """
    images = []
    labels = []
    for image_file in os.listdir(data_dir):
        image_path = os.path.join(data_dir, image_file)
        label_file = os.path.join(label_dir, os.path.splitext(image_file)[0] + '.xml')
        if not os.path.exists(label_file):
            continue

        labels_info = parse_xml(label_file)
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None for unreadable/non-image files; None / 255 would crash.
            print(f"Skipping unreadable image: {image_path}")
            continue
        #img = cv2.resize(img, (500, 500))
        images.append((img / 255).astype(np.float16))

        has_person = any(label == 'person' for label, _ in labels_info)
        labels.append([1, 0] if has_person else [0, 1])
    return np.array(images), np.array(labels)

def visualize_data(images, labels, num_samples=10):
    """Show randomly chosen images (without repeats) in one row, titled with their label.

    Parameters
    ----------
    images : sequence of numpy.ndarray
        Non-uint8 images are assumed to be in [0, 1] and are scaled to uint8.
        NOTE(review): images from ``load_data`` keep the channel order of
        ``cv2.imread`` (BGR), while ``imshow`` expects RGB, so colours may
        appear swapped.
    labels : array-like
        One-hot rows (2-D), which are shown as their argmax, or plain labels (1-D).
    num_samples : int
        Number of images to show. It is capped at ``len(images)``. Nothing is
        drawn when there are no images.
    """
    labels = np.asarray(labels)
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D, so a single sample is indexable too.
    _, axes = plt.subplots(1, num_samples, figsize=(20, 20), squeeze=False)
    indices = np.random.choice(len(images), num_samples, replace=False)
    for ax, idx in zip(axes[0], indices):
        img = images[idx]
        if img.dtype != np.uint8:
            img = (img * 255).astype(np.uint8)

        ax.imshow(img)
        label_str = np.argmax(labels[idx]) if labels.ndim == 2 else labels[idx]
        ax.set_title(f"Label: {label_str}")
        ax.axis('off')

    plt.show()



def replace_weights(model, value_range):
    """Snap every kernel entry of each Conv2D layer in *model* to the nearest allowed value.

    Parameters
    ----------
    model : keras model
        Only the layers in ``model.layers`` are inspected. NOTE(review): a
        nested sub-model is not itself a Conv2D, so convolutions inside it are
        not reached. Confirm that this is intended.
    value_range : iterable of float
        Allowed weight values. It is read once, so generators work. When two
        values are equally close, the one listed first wins.

    The kernel is the first entry of ``layer.get_weights()``. Any further
    entries (such as a bias) are written back unchanged. Layers built without
    a bias are supported.
    """
    levels = np.asarray(list(value_range), dtype=np.float64)
    for layer in model.layers:
        if not isinstance(layer, layers.Conv2D):
            continue
        kernel, *other_weights = layer.get_weights()
        # Compare every kernel entry with every level in one broadcast instead
        # of a Python loop per entry; argmin returns the first minimal index.
        nearest = np.abs(kernel[..., np.newaxis] - levels).argmin(axis=-1)
        layer.set_weights([levels[nearest].astype(kernel.dtype), *other_weights])

def find_nearest(array, value):
    """Return the element of *array* closest to *value*.

    Parameters
    ----------
    array : array-like
        Candidate values (for example the replacement weight levels). Lists are
        accepted, and multi-dimensional input is treated as a flat list of
        candidates.
    value : scalar or array-like
        The value(s) to snap. A scalar gives a scalar result. An array gives an
        array of the same shape, with each entry snapped independently.

    When two candidates are equally close, the one appearing first wins.
    """
    candidates = np.asarray(array).ravel()
    targets = np.asarray(value)
    # Trailing axis enumerates the candidates, so this broadcasts over `value`.
    idx = np.abs(targets[..., np.newaxis] - candidates).argmin(axis=-1)
    return candidates[idx]


# Crop the polar images: each job is (image dir, label dir, output dir).
polar_crop_jobs = (
    ("F:/BFO CNN/PolarLITIS/PolarLITIS/train_polar/PARAM_POLAR/I04590",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/train_polar/LABELS_polar",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/train_polar/I04590_crop"),
    ("F:/BFO CNN/PolarLITIS/PolarLITIS/val_polar/PARAM_POLAR/I04590",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/val_polar/LABELS_polar",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/val_polar/I04590_crop"),
    ("F:/BFO CNN/PolarLITIS/PolarLITIS/test_polar/PARAM_POLAR/I04590",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/test_polar/LABELS_polar",
     "F:/BFO CNN/PolarLITIS/PolarLITIS/test_polar/I04590_crop"),
)
for image_dir, annotation_dir, crop_dir in polar_crop_jobs:
    batch_crop_images(image_dir, annotation_dir, crop_dir)

# #(images, labels, output) rgb_crop
# batch_crop_images("F:/BFO CNN/PolarLITIS/PolarLITIS/train_rgb/RGB_rs", "F:/BFO CNN/PolarLITIS/PolarLITIS/train_rgb/LABELS_RGB", "F:/BFO CNN/PolarLITIS/PolarLITIS/train_rgb/rgb_crop")
# batch_crop_images("F:/BFO CNN/PolarLITIS/PolarLITIS/val_rgb/RGB_rs", "F:/BFO CNN/PolarLITIS/PolarLITIS/val_rgb/LABELS_RGB", "F:/BFO CNN/PolarLITIS/PolarLITIS/val_rgb/rgb_crop")
# batch_crop_images("F:/BFO CNN/PolarLITIS/PolarLITIS/test_rgb/RGB", "F:/BFO CNN/PolarLITIS/PolarLITIS/test_rgb/LABELS_RGB", "F:/BFO CNN/PolarLITIS/PolarLITIS/test_rgb/rgb_crop")



# load data
# #RGB
# train_images, train_labels = load_data("F:/BFO CNN/PolarLITIS/PolarLITIS/train_rgb/rgb_crop", "F:/BFO CNN/PolarLITIS/PolarLITIS/train_rgb/LABELS_RGB")
# val_images, val_labels = load_data("F:/BFO CNN/PolarLITIS/PolarLITIS/val_rgb/rgb_crop", "F:/BFO CNN/PolarLITIS/PolarLITIS/val_rgb/LABELS_RGB")
# test_images, test_labels = load_data("F:/BFO CNN/PolarLITIS/PolarLITIS/test_rgb/rgb_crop", "F:/BFO CNN/PolarLITIS/PolarLITIS/test_rgb/LABELS_RGB")
# POLAR: (cropped image dir, label dir) for train, val and test, in that order.
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = [
    load_data(cropped_dir, labels_dir)
    for cropped_dir, labels_dir in (
        ("F:/BFO CNN/PolarLITIS/PolarLITIS/train_polar/I04590_crop",
         "F:/BFO CNN/PolarLITIS/PolarLITIS/train_polar/LABELS_polar"),
        ("F:/BFO CNN/PolarLITIS/PolarLITIS/val_polar/I04590_crop",
         "F:/BFO CNN/PolarLITIS/PolarLITIS/val_polar/LABELS_polar"),
        ("F:/BFO CNN/PolarLITIS/PolarLITIS/test_polar/I04590_crop",
         "F:/BFO CNN/PolarLITIS/PolarLITIS/test_polar/LABELS_polar"),
    )
]


# #data distribution test
# def plot_data_distribution(labels, title):
#     plt.hist(labels, bins=3, edgecolor='k')
#     plt.title(title)
#     plt.xlabel('Classes')
#     plt.ylabel('Frequency')
#     plt.show()

# plot_data_distribution(train_labels, 'Training Data Distribution')
# plot_data_distribution(val_labels, 'Validation Data Distribution')
# plot_data_distribution(test_labels, 'Test Data Distribution')


# Data inspection shows that, in the training and validation sets, images without people
# (one-hot label [0, 1], argmax 1) far outnumber images with people (label [1, 0], argmax 0).
# In the test set, images without people are slightly fewer than images with people.
# Accordingly, the training and validation data are balanced before training; the test set is left as-is.



def balance_data(images, labels):
    """Oversample the smaller of the two one-hot classes so both have equal counts.

    Samples labelled ``[1, 0]`` (person) and ``[0, 1]`` (no person) are split
    apart. Only the smaller group is resampled, with replacement
    (``sklearn.utils.resample``, ``random_state=42``), up to the size of the
    larger group; the larger group is kept as-is. Rows with any other label
    are dropped.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Images and labels: the ``[0, 1]`` group followed by the ``[1, 0]``
        group, stacked along axis 0.

    Raises
    ------
    ValueError
        If exactly one of the two classes has no samples, since there is
        nothing to oversample it from.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)

    person_mask = np.all(labels == [1, 0], axis=1)
    no_person_mask = np.all(labels == [0, 1], axis=1)
    images_0, labels_0 = images[person_mask], labels[person_mask]
    images_1, labels_1 = images[no_person_mask], labels[no_person_mask]

    if (len(labels_0) == 0) != (len(labels_1) == 0):
        raise ValueError(
            f"balance_data: cannot oversample an empty class "
            f"([1, 0]: {len(labels_0)} samples, [0, 1]: {len(labels_1)} samples)"
        )

    target = max(len(labels_0), len(labels_1))
    # Resample only the minority group. Resampling a group that is already at
    # the target size (with replacement) would drop some of its real samples
    # and duplicate others. Resampling a larger group down to a smaller count
    # would throw data away instead of balancing.
    if len(labels_0) < target:
        images_0, labels_0 = resample(images_0, labels_0,
                                      replace=True,
                                      n_samples=target,
                                      random_state=42)
    if len(labels_1) < target:
        images_1, labels_1 = resample(images_1, labels_1,
                                      replace=True,
                                      n_samples=target,
                                      random_state=42)

    images_balanced = np.vstack((images_1, images_0))
    labels_balanced = np.vstack((labels_1, labels_0))
    return images_balanced, labels_balanced


# Balance the training and validation splits; the test split is used unchanged.
(train_images_balanced, train_labels_balanced), (val_images_balanced, val_labels_balanced) = (
    balance_data(train_images, train_labels),
    balance_data(val_images, val_labels),
)


# plot_data_distribution(train_labels_balanced, 'Balanced Training Data Distribution')
# plot_data_distribution(val_labels_balanced, 'Balanced Validation Data Distribution')
# plot_data_distribution(test_labels, 'Test Data Distribution')

# Data visualization: spot-check random samples from each split for wrong labels.
for split_images, split_labels in (
    (train_images_balanced, train_labels_balanced),
    (val_images_balanced, val_labels_balanced),
    (test_images, test_labels),
):
    visualize_data(split_images, split_labels, num_samples=10)




#########################################################################################
######### Load the bfo weight with the physical meaning A in I = A*sin(theta) ###########
#########################################################################################
# Read the measured amplitudes and min-max normalise them to [0, 1]. Then
# prepend the negatives of the first 8 normalised levels (a zero level stays
# +0.0), giving one flat array of allowed weight values of both signs.
# NOTE(review): the builtin min/max below iterate rows, which only works when
# weight.xlsx has a single column, and at least 8 rows are required. Confirm.
weight_table = pd.read_excel('weight.xlsx', header=None)
bfo_weight = np.array(weight_table)
bfo_low = min(bfo_weight)
bfo_span = max(bfo_weight) - bfo_low
normalized_bfo_weight = (bfo_weight - bfo_low) / bfo_span
positive_levels = normalized_bfo_weight.ravel()
negative_levels = np.array(
    [-positive_levels[k] if positive_levels[k] != 0 else 0.0 for k in range(8)]
)
normalized_bfo_weight = np.append(negative_levels, normalized_bfo_weight)




# Data augmentation: random geometric transforms to improve generalization.
# Options tried but disabled: rescale=1./255, fill_mode='nearest'.
augmentation_settings = {
    'rotation_range': 20,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
}
datagen = ImageDataGenerator(**augmentation_settings)

# NOTE(review): the epoch loop further down fits on train_images_balanced
# directly, so train_generator is built but never consumed. Confirm whether
# augmentation was meant to be used.
train_generator = datagen.flow(train_images_balanced, train_labels_balanced, batch_size=32)
# validation_generator = datagen.flow(val_images_balanced, val_labels_balanced, batch_size=32)



# Load the MobileNetV2 pre-trained model, excluding the top layer
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(128, 128, 3))


# Freeze the backbone except its first two layers, which stay trainable.
# Presumably these are the input layer and 'Conv1', the convolution whose
# kernel the training loop snaps to BFO levels; the printout below lists them.
# The parent model is kept trainable and its children are frozen one by one.
# In tf.keras 2.x, setting base_model.trainable = False makes the base model
# report no trainable weights at all, so re-enabling individual children
# afterwards has no effect and those layers would silently stay frozen.
base_model.trainable = True
for layer in base_model.layers[2:]:
    layer.trainable = False

for layer in base_model.layers:
    print(f"{layer.name}: {layer.trainable}")



# Add custom top layers [PRE VERSION]
model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(16, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(2, activation='softmax')
])



model.compile(optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# # pure simulation version
# history = model.fit(train_images_balanced, train_labels_balanced, epochs=200, batch_size=64,
#                     validation_data=(val_images_balanced, val_labels_balanced),
#                     callbacks=[TestMetricsCallback(), early_stopping])



# Replacing version: train one epoch at a time. After every epoch, snap each
# entry of the backbone's 'Conv1' kernel to the nearest value in
# normalized_bfo_weight. Stop once the best validation loss is more than
# `patience` epochs old.
epochs = 200
patience = 30
tracked_metrics = ('loss', 'accuracy', 'val_loss', 'val_accuracy')
history = {metric: [] for metric in tracked_metrics}


model.evaluate(val_images_balanced, val_labels_balanced)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    history_epoch = model.fit(
        train_images_balanced,
        train_labels_balanced,
        epochs=1,
        batch_size=64,
        validation_data=(val_images_balanced, val_labels_balanced),
    )
    for metric in tracked_metrics:
        history[metric].extend(history_epoch.history[metric])

    # Weight replacement: compare every kernel entry with every level in one
    # broadcast; argmin picks the first of equally close levels.
    conv1 = base_model.get_layer('Conv1')
    kernel = conv1.get_weights()[0]
    nearest = np.abs(kernel[..., np.newaxis] - normalized_bfo_weight).argmin(axis=-1)
    conv1.set_weights([normalized_bfo_weight[nearest]])

    val_losses = history['val_loss']
    if epoch > patience and min(val_losses[-patience:]) > min(val_losses):
        print("Early stopping")
        break

saved_curves = (
    ('polar_val_accuracy.txt', 'val_accuracy'),
    ('polar_val_loss.txt', 'val_loss'),
    ('polar_train_accuracy.txt', 'accuracy'),
    ('polar_train_loss.txt', 'loss'),
)
for out_path, metric in saved_curves:
    np.savetxt(out_path, history[metric])
# np.savetxt('polar_val_accuracy.txt', history.history['val_accuracy'])
# np.savetxt('polar_val_loss.txt', history.history['val_loss'])


# One panel per metric: (y-axis label, (history key, legend), (history key, legend)).
curve_panels = (
    ('Loss', ('loss', 'Train Loss'), ('val_loss', 'Validation Loss')),
    ('Accuracy', ('accuracy', 'Train Accuracy'), ('val_accuracy', 'Validation Accuracy')),
)

plt.figure(figsize=(12, 4))
for panel_index, (y_label, *curves) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_index)
    for history_key, legend_label in curves:
        plt.plot(history[history_key], label=legend_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.show()


